{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "201501f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import pandas as pd\n",
    "from sklearn import preprocessing\n",
    "from tqdm import tqdm\n",
    "import fm\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "def seed_torch(seed=0):\n",
    "    \"\"\"Seed every random number generator this notebook relies on.\n",
    "\n",
    "    Covers Python's ``random``, NumPy and PyTorch (CPU plus every CUDA device), and\n",
    "    puts cuDNN into deterministic, non-benchmarking mode.\n",
    "\n",
    "    :param seed: integer seed handed to each generator\n",
    "    \"\"\"\n",
    "    # exported for child processes; the hash seed of this interpreter is fixed at startup\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    seeders = (\n",
    "        random.seed,\n",
    "        np.random.seed,\n",
    "        torch.manual_seed,\n",
    "        torch.cuda.manual_seed,\n",
    "        torch.cuda.manual_seed_all,  # covers multi-GPU setups\n",
    "    )\n",
    "    for seeder in seeders:\n",
    "        seeder(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "\n",
    "seed_torch(2021)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36614033",
   "metadata": {},
   "source": [
    "## 1. Load Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af3952d7",
   "metadata": {},
   "source": [
    "### (1) define utr_function_predictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b5aee6b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Human5PrimeUTRPredictor(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Regression head for 5' UTR sequences: a 1D CNN over per-position features.\n",
    "\n",
    "    Depending on ``input_types`` the CNN sees a 4-channel one-hot encoding of the\n",
    "    nucleotides (\"seq\"), RNA-FM embeddings reduced from 640 to 32 channels\n",
    "    (\"emb-rnafm\"), or both concatenated along the channel axis. Every input is\n",
    "    padded on the right to ``token_len`` positions.\n",
    "    \"\"\"\n",
    "    def __init__(self, alphabet=None, task=\"rgs\", arch=\"cnn\", input_types=None):\n",
    "        \"\"\"\n",
    "        :param alphabet: backbone alphabet; its padding_idx, eos_idx, append_eos and\n",
    "            prepend_bos attributes are read when auxiliary tokens are stripped\n",
    "        :param task: task tag, stored on the module but not used by it\n",
    "        :param arch: architecture of the head, only \"cnn\" is implemented\n",
    "        :param input_types: collection holding \"seq\" and/or \"emb-rnafm\";\n",
    "            None selects the default [\"seq\", \"emb-rnafm\"]\n",
    "        :raises Exception: for an unknown ``arch`` or when no known input type is selected\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.alphabet = alphabet   # backbone alphabet: pad_idx=1, eos_idx=2, append_eos=True, prepend_bos=True\n",
    "        self.task = task\n",
    "        self.arch = arch\n",
    "        # the default list is built per instance so no list object is shared between models\n",
    "        self.input_types = [\"seq\", \"emb-rnafm\"] if input_types is None else input_types\n",
    "        self.padding_mode = \"right\"\n",
    "        self.token_len = 100\n",
    "        self.out_plane = 1\n",
    "        self.in_channels = 0\n",
    "        if \"seq\" in self.input_types:\n",
    "            self.in_channels = self.in_channels + 4\n",
    "\n",
    "        if \"emb-rnafm\" in self.input_types:\n",
    "            # the attribute name (spelling included) is part of the state_dict keys\n",
    "            self.reductio_module = nn.Linear(640, 32)\n",
    "            self.in_channels = self.in_channels + 32\n",
    "\n",
    "        if self.arch != \"cnn\":\n",
    "            raise Exception(\"Wrong Arch Type\")\n",
    "        if self.in_channels == 0:\n",
    "            raise Exception(\"No supported input type in input_types: {}\".format(self.input_types))\n",
    "        self.predictor = self.create_1dcnn_for_emd(in_planes=self.in_channels, out_planes=self.out_plane)\n",
    "\n",
    "    def forward(self, tokens, inputs):\n",
    "        \"\"\"\n",
    "        :param tokens: 2D integer tensor of backbone tokens; its first and last columns are\n",
    "            dropped as auxiliary tokens when the one-hot input is built\n",
    "        :param inputs: dict; inputs[\"emb-rnafm\"] must hold the backbone embeddings\n",
    "            (last dimension 640) when \"emb-rnafm\" is an input type\n",
    "        :return: output of the CNN head with its last dimension squeezed\n",
    "        \"\"\"\n",
    "        ensemble_inputs = []\n",
    "        if \"seq\" in self.input_types:\n",
    "            # convert RNA-FM token ids (20 tokens) to the nest version (4 tokens A,U,C,G);\n",
    "            # ids below 4 become negative and are masked out below\n",
    "            nest_tokens = tokens[:, 1:-1] - 4\n",
    "            nest_tokens = torch.nn.functional.pad(nest_tokens, (0, self.token_len - nest_tokens.shape[1]), value=-2)\n",
    "            token_padding_mask = nest_tokens.ge(0).long()\n",
    "            # negative ids are zeroed so one_hot accepts them, then their rows are blanked\n",
    "            one_hot_tokens = torch.nn.functional.one_hot((nest_tokens * token_padding_mask), num_classes=4)\n",
    "            one_hot_tokens = one_hot_tokens.float() * token_padding_mask.unsqueeze(-1)\n",
    "            one_hot_tokens = one_hot_tokens.permute(0, 2, 1)  # (B, L, 4) -> (B, 4, L)\n",
    "            ensemble_inputs.append(one_hot_tokens)\n",
    "\n",
    "        if \"emb-rnafm\" in self.input_types:\n",
    "            embeddings = inputs[\"emb-rnafm\"]\n",
    "            # remove auxiliary tokens; the returned mask is not needed here\n",
    "            embeddings, _ = self.remove_pend_tokens_1d(tokens, embeddings)\n",
    "            # zero-pad the sequence axis on the right up to token_len\n",
    "            embeddings = torch.nn.functional.pad(embeddings, (0, 0, 0, self.token_len - embeddings.shape[1]))\n",
    "            # channel reduction\n",
    "            embeddings = self.reductio_module(embeddings)\n",
    "            embeddings = embeddings.permute(0, 2, 1)  # channels first for Conv1d\n",
    "            ensemble_inputs.append(embeddings)\n",
    "\n",
    "        ensemble_inputs = torch.cat(ensemble_inputs, dim=1)\n",
    "        output = self.predictor(ensemble_inputs).squeeze(-1)\n",
    "        return output\n",
    "\n",
    "    def create_1dcnn_for_emd(self, in_planes, out_planes):\n",
    "        \"\"\"\n",
    "        Build the CNN head: stem conv, six residual blocks, global average pooling,\n",
    "        dropout and a final linear layer.\n",
    "\n",
    "        :param in_planes: number of input channels\n",
    "        :param out_planes: number of outputs of the final linear layer\n",
    "        :return: the assembled nn.Sequential\n",
    "        \"\"\"\n",
    "        main_planes = 64\n",
    "        dropout = 0.2\n",
    "        layers = [nn.Conv1d(in_planes, main_planes, kernel_size=3, padding=1)]\n",
    "        # residual blocks alternate between a strided (stride 2) and a plain block\n",
    "        for stride in (2, 1, 2, 1, 2, 1):\n",
    "            layers.append(\n",
    "                ResBlock(main_planes * 1, main_planes * 1, stride=stride, dilation=1, conv_layer=nn.Conv1d,\n",
    "                         norm_layer=nn.BatchNorm1d)\n",
    "            )\n",
    "        layers.extend([\n",
    "            nn.AdaptiveAvgPool1d(1),\n",
    "            nn.Flatten(),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.Linear(main_planes * 1, out_planes),\n",
    "        ])\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def remove_pend_tokens_1d(self, tokens, seqs):\n",
    "        \"\"\"\n",
    "        Zero out and strip the auxiliary-token positions of per-token features.\n",
    "\n",
    "        :param tokens: token ids aligned with the second-to-last axis of ``seqs``\n",
    "        :param seqs: per-token features; the last axis is the feature axis\n",
    "        :return: (features without the last/first column, mask that is True at non-padding\n",
    "            positions or None when no non-padding position is left)\n",
    "        \"\"\"\n",
    "        padding_masks = tokens.ne(self.alphabet.padding_idx)\n",
    "\n",
    "        # suffix first: zero the eos and padding rows, then drop the last column\n",
    "        if self.alphabet.append_eos:\n",
    "            eos_masks = tokens.ne(self.alphabet.eos_idx)\n",
    "            eos_pad_masks = (eos_masks & padding_masks).to(seqs)\n",
    "            seqs = seqs * eos_pad_masks.unsqueeze(-1)\n",
    "            seqs = seqs[:, ..., :-1, :]\n",
    "            padding_masks = padding_masks[:, ..., :-1]\n",
    "\n",
    "        # then drop the first column (bos token)\n",
    "        if self.alphabet.prepend_bos:\n",
    "            seqs = seqs[:, ..., 1:, :]\n",
    "            padding_masks = padding_masks[:, ..., 1:]\n",
    "\n",
    "        if not padding_masks.any():\n",
    "            padding_masks = None\n",
    "\n",
    "        return seqs, padding_masks\n",
    "\n",
    "class ResBlock(nn.Module):\n",
    "    \"\"\"\n",
    "    Pre-activation residual block: (norm -> ReLU -> 3-wide conv) twice, plus a shortcut.\n",
    "\n",
    "    The shortcut is the identity unless the block is strided or changes the channel\n",
    "    count; in that case it is a strided 1-wide convolution followed by a norm layer.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_planes,\n",
    "        out_planes,\n",
    "        stride=1,\n",
    "        dilation=1,\n",
    "        conv_layer=nn.Conv2d,\n",
    "        norm_layer=nn.BatchNorm2d,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        :param in_planes: number of input channels\n",
    "        :param out_planes: number of output channels\n",
    "        :param stride: stride of the first convolution and of the shortcut convolution\n",
    "        :param dilation: dilation of both 3-wide convolutions\n",
    "        :param conv_layer: convolution class, e.g. nn.Conv1d or nn.Conv2d\n",
    "        :param norm_layer: normalisation class matching ``conv_layer``\n",
    "        \"\"\"\n",
    "        super(ResBlock, self).__init__()\n",
    "        # padding == dilation keeps the length of a kernel-3 convolution unchanged, but only\n",
    "        # if the convolution is dilated by the same amount; dilation is therefore passed on\n",
    "        self.bn1 = norm_layer(in_planes)\n",
    "        self.relu1 = nn.ReLU(inplace=True)\n",
    "        self.conv1 = conv_layer(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation,\n",
    "                                dilation=dilation, bias=False)\n",
    "        self.bn2 = norm_layer(out_planes)\n",
    "        self.relu2 = nn.ReLU(inplace=True)\n",
    "        self.conv2 = conv_layer(out_planes, out_planes, kernel_size=3, padding=dilation,\n",
    "                                dilation=dilation, bias=False)\n",
    "\n",
    "        if stride > 1 or out_planes != in_planes:\n",
    "            self.downsample = nn.Sequential(\n",
    "                conv_layer(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),\n",
    "                norm_layer(out_planes),\n",
    "            )\n",
    "        else:\n",
    "            self.downsample = None\n",
    "\n",
    "        self.stride = stride\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the two pre-activated convolutions and add the shortcut of ``x``.\"\"\"\n",
    "        identity = x\n",
    "        out = self.bn1(x)\n",
    "        out = self.relu1(out)\n",
    "        out = self.conv1(out)\n",
    "        out = self.bn2(out)\n",
    "        out = self.relu2(out)\n",
    "        out = self.conv2(out)\n",
    "        if self.downsample is not None:\n",
    "            identity = self.downsample(x)\n",
    "        out += identity\n",
    "        return out\n",
    "    \n",
    "def weights_init(m):\n",
    "    \"\"\"\n",
    "    Initialise one module in place, chosen by its class name (meant for ``Module.apply``).\n",
    "\n",
    "    Linear: weights ~ N(0, 0.001), bias 0. Conv: Kaiming-normal weights (fan_in), bias 0.\n",
    "    Affine BatchNorm: weight 1, bias 0. Any other module is left untouched.\n",
    "    \"\"\"\n",
    "    kind = type(m).__name__\n",
    "    if 'Linear' in kind:\n",
    "        nn.init.normal_(m.weight, std=0.001)\n",
    "        if isinstance(m.bias, nn.Parameter):\n",
    "            nn.init.constant_(m.bias, 0.0)\n",
    "        return\n",
    "    if 'Conv' in kind:\n",
    "        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')\n",
    "        if m.bias is not None:\n",
    "            nn.init.constant_(m.bias, 0.0)\n",
    "        return\n",
    "    if 'BatchNorm' in kind and m.affine:\n",
    "        nn.init.constant_(m.weight, 1.0)\n",
    "        nn.init.constant_(m.bias, 0.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d28a72d6",
   "metadata": {},
   "source": [
    "### (2) create RNA-FM backbone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e1551e78",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "create RNA-FM_backbone sucessfully\n"
     ]
    }
   ],
   "source": [
    "# CUDA_VISIBLE_DEVICES has to be set before CUDA is queried for the first time\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = \"0\"\n",
    "# fall back to the CPU when no GPU is usable instead of failing in .to(\"cuda\")\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "backbone, alphabet = fm.pretrained.rna_fm_t12()\n",
    "backbone.to(device)\n",
    "print(\"create RNA-FM_backbone sucessfully\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d07fc1e",
   "metadata": {},
   "source": [
    "### (3) create UTR function downstream predictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "97442953",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "create utr_func_predictor sucessfully\n",
      "Human5PrimeUTRPredictor(\n",
      "  (reductio_module): Linear(in_features=640, out_features=32, bias=True)\n",
      "  (predictor): Sequential(\n",
      "    (0): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n",
      "    (1): ResBlock(\n",
      "      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu1): ReLU(inplace=True)\n",
      "      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(2,), padding=(1,), bias=False)\n",
      "      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu2): ReLU(inplace=True)\n",
      "      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv1d(64, 64, kernel_size=(1,), stride=(2,), bias=False)\n",
      "        (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (2): ResBlock(\n",
      "      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu1): ReLU(inplace=True)\n",
      "      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
      "      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu2): ReLU(inplace=True)\n",
      "      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
      "    )\n",
      "    (3): ResBlock(\n",
      "      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu1): ReLU(inplace=True)\n",
      "      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(2,), padding=(1,), bias=False)\n",
      "      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu2): ReLU(inplace=True)\n",
      "      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv1d(64, 64, kernel_size=(1,), stride=(2,), bias=False)\n",
      "        (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (4): ResBlock(\n",
      "      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu1): ReLU(inplace=True)\n",
      "      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
      "      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu2): ReLU(inplace=True)\n",
      "      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
      "    )\n",
      "    (5): ResBlock(\n",
      "      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu1): ReLU(inplace=True)\n",
      "      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(2,), padding=(1,), bias=False)\n",
      "      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu2): ReLU(inplace=True)\n",
      "      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv1d(64, 64, kernel_size=(1,), stride=(2,), bias=False)\n",
      "        (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (6): ResBlock(\n",
      "      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu1): ReLU(inplace=True)\n",
      "      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
      "      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu2): ReLU(inplace=True)\n",
      "      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
      "    )\n",
      "    (7): AdaptiveAvgPool1d(output_size=1)\n",
      "    (8): Flatten(start_dim=1, end_dim=-1)\n",
      "    (9): Dropout(p=0.2, inplace=False)\n",
      "    (10): Linear(in_features=64, out_features=1, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# configuration of the downstream head\n",
    "task = \"rgs\"\n",
    "arch = \"cnn\"\n",
    "input_items = [\"emb-rnafm\"]   # alternatives: [\"seq\"], [\"emb-rnafm\"]\n",
    "model_name = \"{}_{}\".format(arch.upper(), \"_\".join(input_items))\n",
    "\n",
    "# build the head, initialise its weights and move it to the target device\n",
    "utr_func_predictor = Human5PrimeUTRPredictor(alphabet, task=task, arch=arch, input_types=input_items)\n",
    "utr_func_predictor = utr_func_predictor.apply(weights_init).to(device)\n",
    "print(\"create utr_func_predictor sucessfully\")\n",
    "print(utr_func_predictor)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be5e47f0",
   "metadata": {},
   "source": [
    "### (4) define loss function and optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8d1208bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reduction=\"none\" keeps one squared error per sample\n",
    "criterion = nn.MSELoss(reduction=\"none\")\n",
    "# only the parameters of utr_func_predictor are handed to the optimizer\n",
    "optimizer = optim.Adam(params=utr_func_predictor.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d2c6444",
   "metadata": {},
   "source": [
    "## 2. Load Data\n",
    "Download the data from the Google Drive link https://drive.google.com/file/d/10zCfOHOaOa__J2AIuZyidZ9sVJ9L11rI/view?usp=sharing or from https://proj.cse.cuhk.edu.hk/rnafm/api/download?filename=GSM4084997_varying_length_25to100.csv, and place it in `tutorials/utr-function-prediction/data`. Alternatively, fetch it from the command line (the URL is quoted so that the shell does not interpret the `?`): `wget \"https://proj.cse.cuhk.edu.hk/rnafm/api/download?filename=GSM4084997_varying_length_25to100.csv\" -O data/GSM4084997_varying_length_25to100.csv`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d679217",
   "metadata": {},
   "source": [
    "### (1) define utr dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6be0f881",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Human_5Prime_UTR_VarLength(object):\n",
    "    \"\"\"\n",
    "    (sequence, scaled ribosome load) pairs read from GSM4084997_varying_length_25to100.csv.\n",
    "\n",
    "    \"train\" is the random library minus its held-out subset, \"valid\" is that held-out\n",
    "    random subset, and any other ``set_name`` selects the held-out human subset. Labels are\n",
    "    standardised with a scaler fitted on the training labels; it is kept in ``self.scaler``.\n",
    "    \"\"\"\n",
    "    def __init__(self, root, set_name=\"train\"):\n",
    "        \"\"\"\n",
    "        :param root: directory that contains data/GSM4084997_varying_length_25to100.csv\n",
    "        :param set_name: \"train\", \"valid\", \"test\"\n",
    "        \"\"\"\n",
    "        self.root = root\n",
    "        self.set_name = set_name\n",
    "        self.src_scv_path = os.path.join(self.root, \"data\", \"GSM4084997_varying_length_25to100.csv\")\n",
    "        self.seqs, self.scaled_rls = self.__dataset_info(self.src_scv_path)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return (sequence with every T replaced by U, scaled label) for ``index``.\"\"\"\n",
    "        seq_str = self.seqs[index].replace(\"T\", \"U\")\n",
    "        label = self.scaled_rls[index]\n",
    "\n",
    "        return seq_str, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.seqs)\n",
    "\n",
    "    @staticmethod\n",
    "    def _filter_by_reads(src_df, set_label, min_reads=10):\n",
    "        \"\"\"Rows of one library with at least ``min_reads`` reads, most-read first.\"\"\"\n",
    "        subset = src_df[src_df['set'] == set_label]\n",
    "        subset = subset[subset['total_reads'] >= min_reads]\n",
    "        return subset.sort_values('total_reads', ascending=False).reset_index(drop=True)\n",
    "\n",
    "    @staticmethod\n",
    "    def _top_reads_per_length(df, lengths=range(25, 101), n_per_length=100):\n",
    "        \"\"\"For every UTR length keep the ``n_per_length`` rows with the most reads.\"\"\"\n",
    "        chunks = []\n",
    "        for length in lengths:\n",
    "            same_len = df[df['len'] == length]\n",
    "            same_len = same_len.sort_values('total_reads', ascending=False).reset_index(drop=True)\n",
    "            chunks.append(same_len.iloc[:n_per_length])\n",
    "        # one concat instead of DataFrame.append in the loop (append was removed in pandas 2.0)\n",
    "        non_empty = [chunk for chunk in chunks if len(chunk) > 0]\n",
    "        return pd.concat(non_empty) if non_empty else df.iloc[:0].copy()\n",
    "\n",
    "    def __dataset_info(self, src_csv_path):\n",
    "        \"\"\"\n",
    "        Read the CSV, build the splits, scale the labels and select ``self.set_name``.\n",
    "\n",
    "        Sets ``self.scaler`` as a side effect.\n",
    "\n",
    "        :param src_csv_path: path of the source CSV file\n",
    "        :return: (array of UTR sequences, array of scaled labels) of the selected split\n",
    "        \"\"\"\n",
    "        # 1.Filter Data (drop UTRs with too few reads)\n",
    "        src_df = pd.read_csv(src_csv_path)\n",
    "        src_df[\"ori_index\"] = src_df.index\n",
    "        random_df = self._filter_by_reads(src_df, 'random')   # 87000 -> 83919\n",
    "        human_df = self._filter_by_reads(src_df, 'human')     # 16739 -> 15555\n",
    "\n",
    "        # 2.Split Dataset\n",
    "        random_df_test = self._top_reads_per_length(random_df)\n",
    "        human_df_test = self._top_reads_per_length(human_df)\n",
    "        # rows present in both frames are the test rows; keep=False removes both copies\n",
    "        train_df = pd.concat([random_df, random_df_test]).drop_duplicates(keep=False)  #  76319\n",
    "\n",
    "        # 3.Label Normalization (ribosome loading value), fitted on the training labels only\n",
    "        label_col = 'rl'\n",
    "        self.scaler = preprocessing.StandardScaler()\n",
    "        self.scaler.fit(train_df[label_col].values.reshape(-1, 1))\n",
    "        for split_df in (train_df, random_df_test, human_df_test):\n",
    "            scaled = self.scaler.transform(split_df[label_col].values.reshape(-1, 1))\n",
    "            split_df['scaled_rl'] = scaled.ravel()\n",
    "\n",
    "        # 4.Pickup Target Set\n",
    "        if self.set_name == \"train\":\n",
    "            set_df = train_df\n",
    "        elif self.set_name == \"valid\":\n",
    "            set_df = random_df_test\n",
    "        else:\n",
    "            set_df = human_df_test\n",
    "        seqs = set_df['utr'].values\n",
    "        scaled_rls = set_df['scaled_rl'].values\n",
    "\n",
    "        print(\"Num samples of {} dataset: {} \".format(self.set_name, set_df[\"len\"].shape[0]))\n",
    "        return seqs, scaled_rls\n",
    "\n",
    "# convert sequences of different lengths into one batch tensor with the same length for each sample.\n",
    "def generate_token_batch(alphabet, seq_strs):\n",
    "    \"\"\"\n",
    "    Tokenise ``seq_strs`` into one int64 tensor, padded on the right with ``padding_idx``.\n",
    "\n",
    "    :param alphabet: provides get_idx, padding_idx, cls_idx, eos_idx, prepend_bos, append_eos\n",
    "    :param seq_strs: sequence strings of one batch\n",
    "    :return: tensor with one row per sequence; the width is the longest sequence plus one\n",
    "        column for each enabled auxiliary token (bos / eos)\n",
    "    \"\"\"\n",
    "    bos_offset = int(alphabet.prepend_bos)\n",
    "    eos_extra = int(alphabet.append_eos)\n",
    "    longest = max(len(seq_str) for seq_str in seq_strs)\n",
    "    tokens = torch.empty((len(seq_strs), longest + bos_offset + eos_extra), dtype=torch.int64)\n",
    "    tokens.fill_(alphabet.padding_idx)\n",
    "\n",
    "    for row, seq_str in enumerate(seq_strs):\n",
    "        seq_end = bos_offset + len(seq_str)\n",
    "        if alphabet.prepend_bos:\n",
    "            tokens[row, 0] = alphabet.cls_idx\n",
    "        tokens[row, bos_offset:seq_end] = torch.tensor(\n",
    "            [alphabet.get_idx(s) for s in seq_str], dtype=torch.int64\n",
    "        )\n",
    "        if alphabet.append_eos:\n",
    "            tokens[row, seq_end] = alphabet.eos_idx\n",
    "    return tokens\n",
    "    \n",
    "def collate_fn(batch):\n",
    "    \"\"\"\n",
    "    Collate (sequence, label) pairs into (sequences, token tensor, label tensor).\n",
    "\n",
    "    Tokenisation uses the notebook-level ``alphabet``.\n",
    "    \"\"\"\n",
    "    seq_strs, labels = zip(*batch)\n",
    "    return seq_strs, generate_token_batch(alphabet, seq_strs), torch.Tensor(labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "690cf64c",
   "metadata": {},
   "source": [
    "### (2) generate dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7029caaa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num samples of train dataset: 76319 \n",
      "Num samples of valid dataset: 7600 \n",
      "Num samples of test dataset: 7600 \n"
     ]
    }
   ],
   "source": [
    "root_path = \"./\"\n",
    "# the three splits are built in the order train, valid, test\n",
    "train_set, val_set, test_set = (\n",
    "    Human_5Prime_UTR_VarLength(root=root_path, set_name=split_name)\n",
    "    for split_name in (\"train\", \"valid\", \"test\")\n",
    ")\n",
    "\n",
    "num_workers = 0\n",
    "batch_size = 64\n",
    "\n",
    "# every loader shares the same settings (all three are shuffled)\n",
    "loader_options = dict(\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=num_workers,\n",
    "    collate_fn=collate_fn,\n",
    "    drop_last=False,\n",
    ")\n",
    "train_loader = DataLoader(train_set, **loader_options)\n",
    "val_loader = DataLoader(val_set, **loader_options)\n",
    "test_loader = DataLoader(test_set, **loader_options)\n",
    "\n",
    "# scaler fitted on the training labels, used to map predictions back to raw values\n",
    "scaler = train_set.scaler"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b62c53c",
   "metadata": {},
   "source": [
    "## 3. Training Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4469a0c7",
   "metadata": {},
   "source": [
    "### (1) define eval function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d015847c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_eval(data_loader, i_epoch, set_name=\"unknown\"):\n",
    "    \"\"\"\n",
    "    Evaluate the predictor on ``data_loader`` and print two averaged errors.\n",
    "\n",
    "    The first is the criterion on the scaled labels, the second the criterion after\n",
    "    predictions and labels went through ``scaler.inverse_transform``. Uses the\n",
    "    notebook-level backbone, utr_func_predictor, criterion, scaler, device and input_items.\n",
    "\n",
    "    :param data_loader: yields (seq_strs, tokens, labels) batches\n",
    "    :param i_epoch: epoch number, only used in the printed messages\n",
    "    :param set_name: name of the data set, only used in the printed messages\n",
    "    :return: mean of the scaled-label losses over all batches (0-dim tensor)\n",
    "    \"\"\"\n",
    "    backbone.eval()\n",
    "    utr_func_predictor.eval()\n",
    "    scaled_errors = []\n",
    "    raw_errors = []\n",
    "    with torch.no_grad():\n",
    "        for seq_strs, tokens, labels in data_loader:\n",
    "            tokens = tokens.to(device)\n",
    "            labels = labels.to(device)\n",
    "\n",
    "            inputs = {}\n",
    "            if \"emb-rnafm\" in input_items:\n",
    "                backbone_out = backbone(tokens, need_head_weights=False, repr_layers=[12], return_contacts=False)\n",
    "                inputs[\"emb-rnafm\"] = backbone_out[\"representations\"][12]\n",
    "            preds = utr_func_predictor(tokens, inputs)\n",
    "            scaled_errors.append(criterion(preds, labels).detach().cpu())\n",
    "\n",
    "            # true value metric\n",
    "            pds = scaler.inverse_transform(preds.detach().cpu().numpy())\n",
    "            gts = scaler.inverse_transform(labels.detach().cpu().numpy())\n",
    "            raw_errors.append(criterion(torch.Tensor(pds), torch.Tensor(gts)).detach().cpu())\n",
    "\n",
    "    avg_loss = torch.cat(scaled_errors, dim=0).mean()\n",
    "    avg_true_rl_mses = torch.cat(raw_errors, dim=0).mean()\n",
    "    for label, value in ((\"MSE loss\", avg_loss), (\"True MSE\", avg_true_rl_mses)):\n",
    "        print(\"Epoch {}, Evaluation on {} Set - {}: {:.3f}\".format(i_epoch, set_name, label, value))\n",
    "\n",
    "    return avg_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3561221",
   "metadata": {},
   "source": [
    "### (2) training process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33bd1c35",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1, Train Set - MSE loss: 0.563: 100%|█████████████████████| 1193/1193 [03:37<00:00,  5.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Evaluation on Random Set - MSE loss: 0.433\n",
      "Epoch 1, Evaluation on Random Set - True MSE: 0.795\n",
      "--------- Model: CNN_emb-rnafm, Best Epoch 1, Best MSE 0.433\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2, Train Set - MSE loss: 0.459: 100%|█████████████████████| 1193/1193 [03:36<00:00,  5.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2, Evaluation on Random Set - MSE loss: 0.398\n",
      "Epoch 2, Evaluation on Random Set - True MSE: 0.731\n",
      "--------- Model: CNN_emb-rnafm, Best Epoch 2, Best MSE 0.398\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3, Train Set - MSE loss: 0.361: 100%|█████████████████████| 1193/1193 [03:35<00:00,  5.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3, Evaluation on Random Set - MSE loss: 0.262\n",
      "Epoch 3, Evaluation on Random Set - True MSE: 0.482\n",
      "--------- Model: CNN_emb-rnafm, Best Epoch 3, Best MSE 0.262\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4, Train Set - MSE loss: 0.305:  23%|█████                 | 274/1193 [00:47<02:37,  5.82it/s]"
     ]
    }
   ],
   "source": [
    "n_epoches = 10\n",
    "best_mse = float(\"inf\")   # any finite first validation MSE counts as an improvement\n",
    "best_epoch = 0\n",
    "os.makedirs(\"result\", exist_ok=True)   # torch.save does not create missing directories\n",
    "\n",
    "for i_e in range(1, n_epoches+1):\n",
    "    # running sum / count of per-sample losses for the progress bar\n",
    "    loss_sum = 0.0\n",
    "    n_sample = 0\n",
    "\n",
    "    # the backbone stays frozen in eval mode; model_eval leaves the predictor in eval mode,\n",
    "    # so it is switched back to train mode at the start of every epoch\n",
    "    backbone.eval()\n",
    "    utr_func_predictor.train()\n",
    "\n",
    "    pbar = tqdm(train_loader, desc=\"Epoch {}, Train Set - MSE loss: {}\".format(i_e, \"NaN\"), ncols=100)\n",
    "    for seq_strs, tokens, labels in pbar:\n",
    "        tokens = tokens.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        inputs = {}\n",
    "        if \"emb-rnafm\" in input_items:\n",
    "            with torch.no_grad():\n",
    "                backbone_out = backbone(tokens, need_head_weights=False, repr_layers=[12], return_contacts=False)\n",
    "            inputs[\"emb-rnafm\"] = backbone_out[\"representations\"][12]\n",
    "        preds = utr_func_predictor(tokens, inputs)\n",
    "        losses = criterion(preds, labels)\n",
    "\n",
    "        # clear gradients first so an interrupted and re-run cell cannot reuse stale ones\n",
    "        optimizer.zero_grad()\n",
    "        losses.mean().backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        loss_sum += losses.detach().sum().item()\n",
    "        n_sample += losses.numel()\n",
    "        pbar.set_description(\"Epoch {}, Train Set - MSE loss: {:.3f}\".format(i_e, loss_sum / n_sample))\n",
    "\n",
    "    random_mse = float(model_eval(val_loader, i_e, set_name=\"Random\"))\n",
    "\n",
    "    if random_mse < best_mse:\n",
    "        best_epoch = i_e\n",
    "        best_mse = random_mse\n",
    "        torch.save(utr_func_predictor.state_dict(), \"result/{}_best_utr_predictor.pth\".format(model_name))\n",
    "    print(\"--------- Model: {}, Best Epoch {}, Best MSE {:.3f}\".format(model_name, best_epoch, best_mse))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
